查看原文
其他

通俗解释优化的线性感知机算法:Pocket PLA

红色石头 AI有道 2019-06-22


AI有道

一个有情怀的公众号



在上一篇文章:


【附源码】一看就懂的感知机算法PLA


我们详细介绍了线性感知机算法模型,并使用pyhon实例,验证了PLA的实际分类效果。下图是PLA实际的分类效果:



但是,文章最后我们提出了一个疑问,就是PLA只能解决线性可分的问题。对于数据本身不是线性可分的情况,又该如何解决呢?下面,我们就将对PLA进行优化,以解决更一般的线性不可分问题。


1

Pocket PLA是什么


首先,我们来看一下线性不可分的例子:



如上图所示,正负样本线性不可分,无法使用PLA算法进行分类,这时候需要对PLA进行优化。优化后的PCA的基本做法很简单,就是如果迭代更新后分类错误样本比前一次少,则更新权重系数 w ;没有减少则保持当前权重系数 w 不变。也就是说,可以把条件放松,即不苛求每个点都分类正确,而是容忍有错误点,取错误点的个数最少时的权重系数 w 。通常在有限的迭代次数里,都能保证得到最佳的分类线。


这种算法也被称为「口袋PLA」Pocket PLA。怎么理解呢?就好像我们在搜寻最佳分类直线的时候,随机选择错误点修正,修正后的直线放在口袋里,暂时作为最佳分类线。然后如果还有错误点,继续随机选择某个错误点修正,修正后的直线与口袋里的分类线比较,把分类错误点较少的分类线放入口袋。一直到迭代次数结束,这时候放在口袋里的一定是最佳分类线,虽然可能还有错误点存在,但已经是最少的了。


2

数据准备


该数据集包含了100个样本,正负样本各50,特征维度为2。


data = pd.read_csv('./data/data2.csv', header=None)
# 样本输入,维度(100,2)
X = data.iloc[:,:2].values
# 样本输出,维度(100,)
y = data.iloc[:,2].values


下面我们在二维平面上绘出正负样本的分布情况。


import matplotlib.pyplot as plt

plt.scatter(X[:50, 0], X[:50, 1], color='blue', marker='o', label='Positive')
plt.scatter(X[50:, 0], X[50:, 1], color='red', marker='x', label='Negative')
plt.xlabel('Feature 1')
plt.ylabel('Feature 2')
plt.legend(loc = 'upper left')
plt.title('Original Data')
plt.show()



很明显,从图中可以看出,正类和负类样本并不是线性可分的。这时候,我们就需要使用Pocket PLA。


3

Pocket PLA代码实现


首先分别对两个特征进行归一化处理,即:


# 均值
u = np.mean(X, axis=0)
# 方差
v = np.std(X, axis=0)

X = (X - u) / v

# 作图
plt.scatter(X[:50, 0], X[:50, 1], color='blue', marker='o', label='Positive')
plt.scatter(X[50:, 0], X[50:, 1], color='red', marker='x', label='Negative')
plt.xlabel('Feature 1')
plt.ylabel('Feature 2')
plt.legend(loc = 'upper left')
plt.title('Normalization data')
plt.show()



接下来对预测直线进行初始化,包括权重 w 初始化:


# X加上偏置项
X = np.hstack((np.ones((X.shape[0],1)), X))
# 权重初始化
w = np.random.randn(3,1)


整个迭代训练过程如下:


for i in range(100):
   s = np.dot(X, w)
   y_pred = np.ones_like(y)
   loc_n = np.where(s < 0)[0]
   y_pred[loc_n] = -1
   num_fault = len(np.where(y != y_pred)[0])
   
   if num_fault == 0:
       break
   else:
       r = np.random.choice(num_fault)        # 随机选择一个错误分类点
       t = np.where(y != y_pred)[0][r]
       w2 = w + y[t] * X[t, :].reshape((3,1))
       
       s = np.dot(X, w2)
       y_pred = np.ones_like(y)
       loc_n = np.where(s < 0)[0]
       y_pred[loc_n] = -1
       num_fault2 = len(np.where(y != y_pred)[0])
       if num_fault2 <num_fault:
           w = w2        # 犯的错误点更少,则更新w,否则w不变


其中,迭代次数为100次,每次迭代随机选择一个错误点进行修正,修正后的分类线错误率与之前的分类线比较,若错误率较低,则选择修正后的分类线。继续进行下一次迭代。


迭代完毕后,得到更新后的权重系数 w ,绘制此时的分类直线是什么样子。


# 直线第一个坐标(x1,y1)
x1 = -2
y1 = -1 / w[2] * (w[0] * 1 + w[1] * x1)
# 直线第二个坐标(x2,y2)
x2 = 2
y2 = -1 / w[2] * (w[0] * 1 + w[1] * x2)
# 作图
plt.scatter(X[:50, 1], X[:50, 2], color='blue', marker='o', label='Positive')
plt.scatter(X[50:, 1], X[50:, 2], color='red', marker='x', label='Negative')
plt.plot([x1,x2], [y1,y2],'r')
plt.xlabel('Feature 1')
plt.ylabel('Feature 2')
plt.legend(loc = 'upper left')
plt.show()



计算一下分类的正确率:


s = np.dot(X, w)
y_pred = np.ones_like(y)
loc_n = np.where(s < 0)[0]
y_pred[loc_n] = -1
accuracy = len(np.where(y == y_pred)[0]) / len(y)
print('accuracy: %.2f' % accuracy)


accuracy: 0.93


分类正确率达到了0.93。


4

PLA的损失函数分析


我们知道,任何一个机器学习问题包含三个方面:模型、策略、算法。从策略来说,无论是PLA还是Pocket PLA,使用的损失函数是统计误分类点的总数,即希望误分类点的总数越少越好,属于0-1损失函数「0-1 Loss Function」。但是,这样的损失函数不是参数 w 的连续可导函数。


从算法实现上来看,PLA每次不断修正错误点,修正公式为:



修正公式的推导,我们在上一篇文章中已经详细解释过。数学上可以证明,PLA算法是收敛的。


而对于分类问题,常见的损失函数一般为交叉熵损失函数「Cross Entropy Loss」。其表达式为:



交叉熵损失函数使用的梯度下降算法修正公式为:



对比起来,PLA和交叉熵损失函数的修正公式具有相似性,不同的是PLA没有引入学习因子η和梯度。


以上内容对比了PLA和一般分类问题在策略和算法上的差异性。其实,红色石头想说的是,抓住本质最为重要,知道了不同的策略和方法,搭配不同的机器学习模型,只要能解决实际问题,都是可以的。也就是说我完全可以使用平方误差来作为分类问题的策略,从理论上讲是可行的。千万不要在解决问题时,只固定一种思路。


5

总结


PCA是机器学习最简单的算法之一。PLA处理线性可分问题,优化的PLA解决线性不可分的问题。实际验证表明,一般的PLA处理线性可分及线性不可分问题都有不错的表现,即一般能得到最佳的分类直线。但是PLA过于简单,有其本身的局限性。


本文完整代码我已上传到GitHub上,需要的点击「阅读原文」自行获取。喜欢的话,不妨点个Star。


P.S. 有兴趣的读者朋友也可以看看李航的《统计学习方法》第二章关于PLA的介绍,其思路和做法与我说的有所不同,使用的损失函数是误分类点到超平面的距离,效果应该更好一些。


推荐阅读:

【1】精心整理 | 林轩田机器学习资源汇总

【2】干货 | 林轩田机器学习「基石+技法」历史文章汇总

【3】干货 | 吴恩达deeplearning.ai专项课程历史文章汇总

【4】简单的梯度下降算法,你真的懂了吗

【5】机器学习中的维度灾难



长按二维码,扫描关注!


Modified on

    您可能也对以下帖子感兴趣

    文章有问题?点此查看未经处理的缓存